import logging
from typing import Any, List, Mapping, Optional
import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_community.llms.utils import enforce_stop_tokens
from configs.server_info import LLMsConfig
from utils.singleton import singleton_llm

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Settings for the 'chatglm3-6b' entry of the project's LLM configuration,
# loaded once at import time. The 'host' and 'port' keys are read below to
# build the default endpoint address of the ChatGLM3 wrapper.
chatglm3_config = LLMsConfig(conf_type='llms').llms['chatglm3-6b']

@singleton_llm
class ChatGLM3(LLM):
    """LangChain LLM wrapper around a ChatGLM3 chat-completions HTTP endpoint.

    Each call posts a fixed system message followed by the user prompt to
    ``http://<endpoint_url>/v1/chat/completions`` and returns the message
    content of the first choice in the JSON reply.
    """

    # "<host>:<port>" of the inference server (no scheme); built from the
    # module-level configuration when the class body is evaluated.
    endpoint_url: str = f"{chatglm3_config['host']}:{chatglm3_config['port']}"
    # Optional extra request-body entries; merged over the built-in payload
    # in ``_call`` and reported by ``_identifying_params``.
    model_kwargs: Optional[dict] = None
    # NOTE: ``max_token``, ``temperature``, ``top_p`` and ``do_sample`` are
    # declared here but are not sent by ``_call``, which keeps its fixed
    # request values. Use ``model_kwargs`` to override what is sent.
    max_token: int = 80000
    temperature: float = 0.1
    # Record of replies as ``[None, reply]`` pairs; only appended to when
    # ``with_history`` is true. It is not sent back to the server.
    history: List[List] = []
    top_p: float = 0.9
    with_history: bool = False
    do_sample: bool = True

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM implementation."""
        return "ChatGLM3_LLM"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "endpoint_url": self.endpoint_url,
            "model_kwargs": self.model_kwargs or {},
        }

    @staticmethod
    def _extract_content(parsed_response: Any) -> str:
        """Return the first choice's message content from a decoded reply.

        Args:
            parsed_response: The JSON-decoded body returned by the endpoint.

        Returns:
            The ``choices[0].message.content`` string, or ``""`` when the
            message carries no content.

        Raises:
            ValueError: If the reply has no usable ``choices[0].message``.
        """
        try:
            message = parsed_response["choices"][0]["message"]
            content = message.get("content", "")
        except (KeyError, IndexError, TypeError, AttributeError) as e:
            raise ValueError(
                "Unexpected response format from inference endpoint: "
                f"{parsed_response}"
            ) from e
        return content if content is not None else ""

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a ChatGLM LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The string generated by the model, truncated at the first stop
            word when ``stop`` is given.

        Raises:
            ValueError: If the request fails, the status code is not 200,
                the body is not valid JSON, or the reply lacks
                ``choices[0].message``.

        Example:
            .. code-block:: python

                response = chatglm_llm("Who are you?")
        """
        # Caller-supplied payload overrides; empty when model_kwargs is unset,
        # in which case the request is exactly the built-in payload below.
        extra_payload = self.model_kwargs or {}

        # Only the content type is sent; no authorization header is used.
        headers = {"Content-Type": "application/json"}

        chat_messages = [
            {
                "role": "system",
                "content": "从现在开始扮演一个专业人士和我对话",
            },
            {"role": "user", "content": prompt},
        ]

        data = {
            "model": "chatglm3",  # model name
            "messages": chat_messages,  # system message + current prompt
            "stream": False,  # whether to stream the response
            "max_tokens": 2000,  # maximum number of tokens to generate
            "temperature": 0.8,  # sampling temperature
            "top_p": 0.8,  # nucleus-sampling probability
            **extra_payload,
        }

        # Call the API.
        try:
            response = requests.post(
                f"http://{self.endpoint_url}/v1/chat/completions",
                headers=headers,
                json=data,
                stream=False,
            )
        except requests.exceptions.RequestException as e:
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            parsed_response = response.json()
        except ValueError as e:
            # Every JSON decode error raised by ``Response.json`` is a
            # ValueError subclass, so this also works on requests versions
            # that predate ``requests.exceptions.JSONDecodeError``.
            raise ValueError(
                f"Error raised during decoding response from inference endpoint: {e}."
                f"\nResponse: {response.text}"
            ) from e

        raw_text = self._extract_content(parsed_response)

        text = raw_text
        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        if self.with_history:
            # The chat-completions reply has no "response" key; record the
            # untruncated model reply that was extracted above.
            self.history = self.history + [[None, raw_text]]
        return text